{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "np.random.seed(2) # seed NumPy's global RNG so the pseudo-random draws are reproducible"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "N_STATE = 6 # number of positions in the 1-D world (rows of the Q-table); the last one is the goal 'T'\n",
    "ACTIONS = [\"left\", \"right\"] # action names, used as the Q-table column labels\n",
    "EPSILON = 0.9 # probability of taking the greedy (highest-valued) action\n",
    "ALPHA = 0.1 # learning rate\n",
    "LAMBDA = 0.9 # discount factor applied to the next state's best Q-value\n",
    "MAX_EXISODES = 13 # number of training episodes to play\n",
    "FRESH_TIME = 0.1 # seconds to pause after drawing each step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "o----T\r",
      "-o---T"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\ProgramData\\Anaconda3\\lib\\site-packages\\ipykernel_launcher.py:66: DeprecationWarning: \n",
      ".ix is deprecated. Please use\n",
      ".loc for label based indexing or\n",
      ".iloc for positional indexing\n",
      "\n",
      "See the documentation here:\n",
      "http://pandas.pydata.org/pandas-docs/stable/indexing.html#ix-indexer-is-deprecated\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                         "
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style>\n",
       "    .dataframe thead tr:only-child th {\n",
       "        text-align: right;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>left</th>\n",
       "      <th>right</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.000103</td>\n",
       "      <td>0.005979</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.000004</td>\n",
       "      <td>0.034896</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.000007</td>\n",
       "      <td>0.146238</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.016031</td>\n",
       "      <td>0.358074</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.012253</td>\n",
       "      <td>0.745813</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       left     right\n",
       "0  0.000103  0.005979\n",
       "1  0.000004  0.034896\n",
       "2  0.000007  0.146238\n",
       "3  0.016031  0.358074\n",
       "4  0.012253  0.745813\n",
       "5  0.000000  0.000000"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def built_q_table(n_state, actions):\n",
    "    \n",
    "    table = pd.DataFrame(np.zeros((n_state, len(actions))), \n",
    "                        columns=actions)\n",
    "    return table\n",
    "\n",
    "def choose_action(state, q_table):\n",
    "    \n",
    "    state_actions = q_table.iloc[state, :]\n",
    "    if (np.random.uniform() > EPSILON) or (state_actions.all() == 0):\n",
    "        \n",
    "        action_name = np.random.choice(ACTIONS)\n",
    "    else:\n",
    "        action_name = state_actions.argmax()\n",
    "        \n",
    "    return action_name\n",
    "\n",
    "# 获取state\n",
    "def get_env_feedback(S, A):\n",
    "    \n",
    "    if A == \"right\":\n",
    "        if S == N_STATE - 2:\n",
    "            S_ = \"terminal\"\n",
    "            R = 1\n",
    "        else:\n",
    "            S_ = S + 1\n",
    "            R = 0\n",
    "    \n",
    "    else:\n",
    "        R = 0\n",
    "        if S == 0:\n",
    "            S_ = S\n",
    "        else:\n",
    "            S_ = S - 1\n",
    "    return S_, R\n",
    "\n",
    "# Rendering: draw the 1-D world, or the episode summary once it has ended\n",
    "def update_env(S, episode, step_conter):\n",
    "    \"\"\"Redraw the single-line text view of the world and pause.\n",
    "\n",
    "    Every print starts with a carriage return and has no trailing newline,\n",
    "    so each frame overwrites the previous one on the same output line.\n",
    "\n",
    "    Args:\n",
    "        S: Current state index, or the string \"terminal\".\n",
    "        episode: Zero-based episode number (shown one-based in the summary).\n",
    "        step_conter: Number of steps taken so far in this episode.\n",
    "    \"\"\"\n",
    "    if S == \"terminal\":\n",
    "        # Episode over: show the summary for two seconds, then blank the line.\n",
    "        summary = \"Episode %s:total_steps:%s\" % (episode + 1, step_conter)\n",
    "        print(\"\\r{}\".format(summary), end=\"\")\n",
    "        time.sleep(2)\n",
    "        print(\"\\r                         \", end=\"\")\n",
    "        return\n",
    "\n",
    "    # N_STATE - 1 track cells followed by the goal marker, agent drawn as 'o'.\n",
    "    track = [\"-\"] * (N_STATE - 1) + [\"T\"]\n",
    "    track[S] = \"o\"\n",
    "    print(\"\\r{}\".format(\"\".join(track)), end=\"\")\n",
    "    time.sleep(FRESH_TIME)\n",
    "\n",
    "def rl():\n",
    "    \"\"\"Train a tabular Q-learning agent on the 1-D walk and return its Q-table.\n",
    "\n",
    "    Plays MAX_EXISODES episodes. Each episode starts in state 0 and runs until\n",
    "    get_env_feedback reports the \"terminal\" state. Every step is rendered\n",
    "    through update_env, which sleeps, so training is deliberately slow.\n",
    "\n",
    "    Returns:\n",
    "        The learned Q-table as a pandas DataFrame (rows are states, columns\n",
    "        are the entries of ACTIONS).\n",
    "    \"\"\"\n",
    "    q_table = built_q_table(N_STATE, ACTIONS)\n",
    "    for episode in range(MAX_EXISODES):\n",
    "        step_conter = 0\n",
    "        S = 0\n",
    "        is_terminated = False\n",
    "        update_env(S, episode, step_conter)\n",
    "\n",
    "        while not is_terminated:\n",
    "            A = choose_action(S, q_table)\n",
    "            S_, R = get_env_feedback(S, A)\n",
    "            # Label-based lookup with .loc (the .ix indexer no longer exists in\n",
    "            # pandas >= 1.0). Rows carry the default integer labels, so the\n",
    "            # state number doubles as the row label.\n",
    "            q_predict = q_table.loc[S, A]  # current estimate\n",
    "\n",
    "            if S_ != \"terminal\":\n",
    "                # Target: reward plus the discounted best value of the next state.\n",
    "                q_target = R + LAMBDA * q_table.iloc[S_, :].max()\n",
    "            else:\n",
    "                # No future value after the terminal state.\n",
    "                q_target = R\n",
    "                is_terminated = True\n",
    "\n",
    "            # Move the estimate a fraction ALPHA of the way toward the target.\n",
    "            q_table.loc[S, A] += ALPHA * (q_target - q_predict)\n",
    "            S = S_\n",
    "\n",
    "            step_conter += 1\n",
    "            update_env(S, episode, step_conter)\n",
    "    return q_table\n",
    "# Run training; as the cell's last expression, the returned Q-table is displayed\n",
    "rl()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
